import java.util.*;

public class _834 {
    /**
     * Sum of distances in a tree, computed with two traversals ("re-rooting").
     *
     * <p>Pass 1 roots the tree at node 0. It computes {@code res[0]}, the sum of the depths of
     * all nodes, and {@code count[i]}, the number of nodes in the subtree rooted at {@code i}.
     * Pass 2 pushes answers from each parent to its children.
     *
     * <p>Both traversals use explicit stacks rather than recursion. Very deep trees (e.g. a long
     * path) therefore cannot overflow the thread stack.
     */
    static class Solution1 {
        /**
         * Returns, for every node, the sum of its distances to all other nodes.
         *
         * @param n number of nodes, labelled {@code 0..n-1}
         * @param edges undirected tree edges, each {@code {u, v}}
         * @return array whose {@code i}-th entry is the sum of distances from node {@code i};
         *     empty when {@code n == 0}
         */
        public int[] sumOfDistancesInTree(int n, int[][] edges) {
            if (n == 0) {
                return new int[0];
            }
            Map<Integer, Set<Integer>> graph = new HashMap<>();
            for (int[] edge : edges) {
                graph.computeIfAbsent(edge[0], k -> new HashSet<>()).add(edge[1]);
                graph.computeIfAbsent(edge[1], k -> new HashSet<>()).add(edge[0]);
            }
            int[] res = new int[n];
            int[] count = new int[n];
            // Pass 1: rooted at 0, gives res[0] and every subtree size.
            res[0] = computeCurDistanceAndCurCount(graph, count, 0, -1, 0);
            // Pass 2: derive every other res[i] from its parent's answer.
            computeRes(graph, count, res, 0, -1);
            return res;
        }

        /**
         * Walks the subtree rooted at {@code cur}, never stepping back into {@code pre}.
         *
         * <p>Sets {@code count[v]} to the size of {@code v}'s subtree for every visited node.
         *
         * @param graph adjacency sets of the tree
         * @param count output array of subtree sizes, indexed by node
         * @param cur root of the subtree to walk
         * @param pre neighbour of {@code cur} to exclude (use {@code -1} for none)
         * @param deep depth assigned to {@code cur}
         * @return the sum of the depths of all visited nodes, where {@code cur} has depth
         *     {@code deep}
         */
        public int computeCurDistanceAndCurCount(Map<Integer, Set<Integer>> graph, int[] count,
                int cur, int pre, int deep) {
            // Each frame is {node, parent, depth}; frames are recorded in preorder.
            List<int[]> preorder = new ArrayList<>();
            Deque<int[]> stack = new ArrayDeque<>();
            stack.push(new int[] {cur, pre, deep});
            int total = 0;
            while (!stack.isEmpty()) {
                int[] frame = stack.pop();
                int node = frame[0];
                int parent = frame[1];
                int depth = frame[2];
                preorder.add(frame);
                total += depth;
                count[node] = 1;
                for (int child : graph.getOrDefault(node, Collections.<Integer>emptySet())) {
                    if (child != parent) {
                        stack.push(new int[] {child, node, depth + 1});
                    }
                }
            }
            // Reverse preorder finishes every descendant before its parent, so subtree sizes can
            // be folded upward. Index 0 is the root, whose size must not leak into 'pre'.
            for (int i = preorder.size() - 1; i > 0; i--) {
                int[] frame = preorder.get(i);
                count[frame[1]] += count[frame[0]];
            }
            return total;
        }

        /**
         * Fills {@code res} for every node below {@code cur}, excluding the {@code pre} side.
         *
         * <p>Relies on {@code res[cur]} already being set by the caller, and on {@code count}
         * holding subtree sizes for the same rooting.
         *
         * @param graph adjacency sets of the tree
         * @param count subtree sizes; {@code count.length} is taken as the total node count
         * @param res answer array; read at {@code cur}, written for all descendants
         * @param cur node whose answer is already known
         * @param pre neighbour of {@code cur} to exclude (use {@code -1} for none)
         */
        public void computeRes(Map<Integer, Set<Integer>> graph, int[] count, int[] res,
                int cur, int pre) {
            int total = count.length;
            // Each frame is {node, parent}; res[node] is always assigned before node is pushed.
            Deque<int[]> stack = new ArrayDeque<>();
            stack.push(new int[] {cur, pre});
            while (!stack.isEmpty()) {
                int[] frame = stack.pop();
                int node = frame[0];
                int parent = frame[1];
                for (int child : graph.getOrDefault(node, Collections.<Integer>emptySet())) {
                    if (child == parent) {
                        continue;
                    }
                    // Re-rooting from node to child brings the count[child] nodes in child's
                    // subtree one step closer. The other total - count[child] nodes move one
                    // step further away.
                    res[child] = res[node] + total - 2 * count[child];
                    stack.push(new int[] {child, node});
                }
            }
        }
    }
}
